{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pprint\n",
    "import sys\n",
    "if \"../\" not in sys.path:\n",
    "  sys.path.append(\"../\") \n",
    "from lib.envs.gridworld import GridworldEnv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "pp = pprint.PrettyPrinter(indent=2)\n",
    "env = GridworldEnv()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Taken from Policy Evaluation Exercise!\n",
    "\n",
    "def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):\n",
    "    \"\"\"\n",
    "    Evaluate a policy given an environment and a full description of the environment's dynamics.\n",
    "    \n",
    "    Args:\n",
    "        policy: [S, A] shaped matrix representing the policy.\n",
    "        env: OpenAI env. env.P represents the transition probabilities of the environment.\n",
    "            env.P[s][a] is a list of transition tuples (prob, next_state, reward, done).\n",
    "            env.nS is a number of states in the environment. \n",
    "            env.nA is a number of actions in the environment.\n",
    "        theta: We stop evaluation once our value function change is less than theta for all states.\n",
    "        discount_factor: Gamma discount factor.\n",
    "    \n",
    "    Returns:\n",
    "        Vector of length env.nS representing the value function.\n",
    "    \"\"\"\n",
    "    # Start with a random (all 0) value function\n",
    "    V = np.zeros(env.nS)\n",
    "    while True:\n",
    "        delta = 0\n",
    "        # For each state, perform a \"full backup\"\n",
    "        for s in range(env.nS):\n",
    "            v = 0\n",
    "            # Look at the possible next actions\n",
    "            for a, action_prob in enumerate(policy[s]):\n",
    "                # For each action, look at the possible next states...\n",
    "                for  prob, next_state, reward, done in env.P[s][a]:\n",
    "                    # Calculate the expected value\n",
    "                    v += action_prob * prob * (reward + discount_factor * V[next_state])\n",
    "            # How much our value function changed (across any states)\n",
    "            delta = max(delta, np.abs(v - V[s]))\n",
    "            V[s] = v\n",
    "        # Stop evaluating once our value function change is below a threshold\n",
    "        if delta < theta:\n",
    "            break\n",
    "    return np.array(V)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):\n",
    "    \"\"\"\n",
    "    Policy Improvement Algorithm. Iteratively evaluates and improves a policy\n",
    "    until an optimal policy is found.\n",
    "    \n",
    "    Args:\n",
    "        env: The OpenAI environment.\n",
    "        policy_eval_fn: Policy Evaluation function that takes 3 arguments:\n",
    "            policy, env, discount_factor.\n",
    "        discount_factor: gamma discount factor.\n",
    "        \n",
    "    Returns:\n",
    "        A tuple (policy, V). \n",
    "        policy is the optimal policy, a matrix of shape [S, A] where each state s\n",
    "        contains a valid probability distribution over actions.\n",
    "        V is the value function for the optimal policy.\n",
    "        \n",
    "    \"\"\"\n",
    "\n",
    "    def one_step_lookahead(state, V):\n",
    "        \"\"\"\n",
    "        Helper function to calculate the value for all action in a given state.\n",
    "        \n",
    "        Args:\n",
    "            state: The state to consider (int)\n",
    "            V: The value to use as an estimator, Vector of length env.nS\n",
    "        \n",
    "        Returns:\n",
    "            A vector of length env.nA containing the expected value of each action.\n",
    "        \"\"\"\n",
    "        A = np.zeros(env.nA)\n",
    "        for a in range(env.nA):\n",
    "            for prob, next_state, reward, done in env.P[state][a]:\n",
    "                A[a] += prob * (reward + discount_factor * V[next_state])\n",
    "        return A\n",
    "    \n",
    "    # Start with a random policy\n",
    "    policy = np.ones([env.nS, env.nA]) / env.nA\n",
    "    \n",
    "    while True:\n",
    "        # Evaluate the current policy\n",
    "        V = policy_eval_fn(policy, env, discount_factor)\n",
    "        \n",
    "        # Will be set to false if we make any changes to the policy\n",
    "        policy_stable = True\n",
    "        \n",
    "        # For each state...\n",
    "        for s in range(env.nS):\n",
    "            # The best action we would take under the current policy\n",
    "            chosen_a = np.argmax(policy[s])\n",
    "            \n",
    "            # Find the best action by one-step lookahead\n",
    "            # Ties are resolved arbitarily\n",
    "            action_values = one_step_lookahead(s, V)\n",
    "            best_a = np.argmax(action_values)\n",
    "            \n",
    "            # Greedily update the policy\n",
    "            if chosen_a != best_a:\n",
    "                policy_stable = False\n",
    "            policy[s] = np.eye(env.nA)[best_a]\n",
    "        \n",
    "        # If the policy is stable we've found an optimal policy. Return it\n",
    "        if policy_stable:\n",
    "            return policy, V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Policy Probability Distribution:\n",
      "[[1. 0. 0. 0.]\n",
      " [0. 0. 0. 1.]\n",
      " [0. 0. 0. 1.]\n",
      " [0. 0. 1. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 0. 1. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [0. 0. 1. 0.]\n",
      " [1. 0. 0. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [1. 0. 0. 0.]]\n",
      "\n",
      "Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\n",
      "[[0 3 3 2]\n",
      " [0 0 0 2]\n",
      " [0 0 1 2]\n",
      " [0 1 1 0]]\n",
      "\n",
      "Value Function:\n",
      "[ 0. -1. -2. -3. -1. -2. -3. -2. -2. -3. -2. -1. -3. -2. -1.  0.]\n",
      "\n",
      "Reshaped Grid Value Function:\n",
      "[[ 0. -1. -2. -3.]\n",
      " [-1. -2. -3. -2.]\n",
      " [-2. -3. -2. -1.]\n",
      " [-3. -2. -1.  0.]]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "policy, v = policy_improvement(env)\n",
    "print(\"Policy Probability Distribution:\")\n",
    "print(policy)\n",
    "print(\"\")\n",
    "\n",
    "print(\"Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\")\n",
    "print(np.reshape(np.argmax(policy, axis=1), env.shape))\n",
    "print(\"\")\n",
    "\n",
    "print(\"Value Function:\")\n",
    "print(v)\n",
    "print(\"\")\n",
    "\n",
    "print(\"Reshaped Grid Value Function:\")\n",
    "print(v.reshape(env.shape))\n",
    "print(\"\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the value function\n",
    "expected_v = np.array([ 0, -1, -2, -3, -1, -2, -3, -2, -2, -3, -2, -1, -3, -2, -1,  0])\n",
    "np.testing.assert_array_almost_equal(v, expected_v, decimal=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
